{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Image Classification for Coffee or Donuts\n",
    "\n",
    "Is the given image a coffee photo, or a donut?\n",
    "In this demo notebook we build a simple deep neural network to classify an image as one of the following:\n",
    "- coffee\n",
    "- mug\n",
    "- donut\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install Prerequisites\n",
    "\n",
    "We use Keras/Tensorflow to build the classification model, and visualize the process with matplotlib."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install tensorflow==1.15 ibm-cos-sdk==2.6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import required libraries\n",
    "import os\n",
    "import uuid\n",
    "import shutil\n",
    "import json\n",
    "from botocore.client import Config\n",
    "import ibm_boto3\n",
    "import tensorflow as tf\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Read The Data\n",
    "\n",
    "Here we build simple wrapper functions to read in the data from our cloud object storage buckets and extract it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @hidden_cell\n",
    "# The following code contains the credentials for a file in your IBM Cloud Object Storage.\n",
    "# You might want to remove those credentials before you share your notebook.\n",
    "credentials = {\n",
    "    'BUCKET': '$$$BUCKET$$$',\n",
    "    'IBM_API_KEY_ID': '$$$IBM_API_KEY_ID$$$',\n",
    "    'IAM_SERVICE_ID': '$$$IAM_SERVICE_ID$$$',\n",
    "    'ENDPOINT': '$$$ENDPOINT$$$',\n",
    "}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A download function from IBM Cloud"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def download_file_cos(credentials, local_file_name, key): \n",
    "    '''\n",
    "    Wrapper function to download a file from cloud object storage using the\n",
    "    credential dict provided and loading it into memory\n",
    "    '''\n",
    "    cos = ibm_boto3.client(service_name='s3',\n",
    "    ibm_api_key_id=credentials['IBM_API_KEY_ID'],\n",
    "    ibm_service_instance_id=credentials['IAM_SERVICE_ID'],\n",
    "    config=Config(signature_version='oauth'),\n",
    "    endpoint_url=credentials['ENDPOINT'])\n",
    "    try:\n",
    "        res=cos.download_file(Bucket=credentials['BUCKET'],Key=key,Filename=local_file_name)\n",
    "    except Exception as e:\n",
    "        print(Exception, e)\n",
    "    else:\n",
    "        print('File Downloaded')\n",
    "\n",
    "def get_annotations(credentials): \n",
    "    cos = ibm_boto3.client(\n",
    "        service_name='s3',\n",
    "        ibm_api_key_id=credentials['IBM_API_KEY_ID'],\n",
    "        ibm_service_instance_id=credentials['IAM_SERVICE_ID'],\n",
    "        config=Config(signature_version='oauth'),\n",
    "        endpoint_url=credentials['ENDPOINT']\n",
    "    )\n",
    "    try:\n",
    "        return json.loads(cos.get_object(Bucket=credentials['BUCKET'], Key='_annotations.json')['Body'].read())\n",
    "    except Exception as e:\n",
    "        print(Exception, e)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = 'data'\n",
    "if os.path.exists(base_path) and os.path.isdir(base_path):\n",
    "    shutil.rmtree(base_path)\n",
    "os.makedirs(base_path, exist_ok=True)\n",
    "\n",
    "annotations = get_annotations(credentials)\n",
    "\n",
    "for i, image in enumerate(annotations['annotations'].keys()):\n",
    "    label = annotations['annotations'][image][0]['label']\n",
    "    os.makedirs(os.path.join(base_path, label), exist_ok=True)\n",
    "    _, extension = os.path.splitext(image)\n",
    "    local_path = os.path.join(base_path, label, str(uuid.uuid4()) + extension)\n",
    "    download_file_cos(credentials, local_path, image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!ls data/coffee\n",
    "!ls data/donut\n",
    "!ls data/mug"
   ]
  },
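  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an additional sanity check, we can count the downloaded images per class with Python. This is a minimal sketch that assumes the `data/<label>` folder layout created above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count the downloaded images in each class folder (sanity check)\n",
    "for label in sorted(os.listdir(base_path)):\n",
    "    count = len(os.listdir(os.path.join(base_path, label)))\n",
    "    print('{}: {} images'.format(label, count))"
   ]
  },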
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build the Model\n",
    "\n",
    "We start with a [MobileNetV2](https://arxiv.org/abs/1801.04381) architecture as the backbone [pretrained feature extractor](https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet). We then add a couple of dense layers and a softmax layer to perfom the classification. We freeze the MobileNetV2 backbone with weights trained on ImageNet dataset and only train the dense layers and softmax layer that we have added."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_model=tf.keras.applications.MobileNetV2(weights='imagenet',include_top=False) #imports the mobilenet model and discards the last 1000 neuron layer.\n",
    "x=base_model.output\n",
    "x=tf.keras.layers.GlobalAveragePooling2D()(x)\n",
    "x=tf.keras.layers.Dense(512,activation='relu')(x) #dense layer 1\n",
    "x=tf.keras.layers.Dense(256,activation='relu')(x) #dense layer 2\n",
    "preds=tf.keras.layers.Dense(3,activation='softmax')(x) #final layer with softmax activation\n",
    "\n",
    "model=tf.keras.Model(inputs=base_model.input,outputs=preds)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Freeze layers from MobileNetV2 backbone (not to be trained)\n",
    "for layer in base_model.layers:\n",
    "    layer.trainable=False"
   ]
  },
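  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To confirm the freeze took effect, a quick optional check counts how many layers of the combined model are trainable versus frozen."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count trainable vs. frozen layers after freezing the backbone\n",
    "trainable = sum(1 for layer in model.layers if layer.trainable)\n",
    "frozen = sum(1 for layer in model.layers if not layer.trainable)\n",
    "print('Trainable layers: {}, frozen layers: {}'.format(trainable, frozen))"
   ]
  },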
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Prepare the training dataset as a data generator object\n",
    "train_datagen=tf.keras.preprocessing.image.ImageDataGenerator(preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input) #included in our dependencies\n",
    "\n",
    "train_generator=train_datagen.flow_from_directory('data',\n",
    "                                                 target_size=(224,224),\n",
    "                                                 color_mode='rgb',\n",
    "                                                 batch_size=10,\n",
    "                                                 class_mode='categorical',\n",
    "                                                 shuffle=True)"
   ]
  },
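  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training cell below uses a hard-coded `step_size_train = 5` batches per epoch. If you prefer, the number of batches needed to see every image once per epoch can instead be derived from the generator itself; a minimal sketch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Number of batches needed to cover the whole dataset once per epoch\n",
    "derived_steps = math.ceil(train_generator.samples / train_generator.batch_size)\n",
    "print('Batches per epoch to see every image once:', derived_steps)"
   ]
  },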
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using Adam, categorical_crossentropy and accuracy as optimization method, loss function and metrics, respectively"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model\n",
    "model.compile(optimizer='Adam',loss='categorical_crossentropy',metrics=['accuracy'])\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from tensorflow import set_random_seed\n",
    "set_random_seed(2)\n",
    "step_size_train=5\n",
    "log_file = model.fit_generator(generator=train_generator,\n",
    "                   steps_per_epoch=step_size_train,\n",
    "                   epochs=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figure of Training Loss and Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model accuracy and loss vs epoch\n",
    "plt.plot(log_file.history['acc'], '-bo', label=\"train_accuracy\")\n",
    "plt.plot(log_file.history['loss'], '-r*', label=\"train_loss\")\n",
    "plt.title('Training Loss and Accuracy')\n",
    "plt.ylabel('Loss/Accuracy')\n",
    "plt.xlabel('Epoch #')\n",
    "plt.legend(loc='center right')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Performance\n",
    "\n",
    "Here we perform inference on some sample data points to determine the performance of the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mapping labels \n",
    "label_map = (train_generator.class_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "label_map"
   ]
  },
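  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The prediction helper below looks up the class name whose index matches the model's `argmax` output. An equivalent, slightly more direct option is to invert `label_map` once into an index-to-name dictionary; a minimal sketch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Invert the class-name -> index mapping so a predicted index can be looked up directly\n",
    "index_to_label = {index: name for name, index in label_map.items()}\n",
    "index_to_label"
   ]
  },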
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Creating a sample inference function\n",
    "def prediction(image_path, model):\n",
    "    img = tf.keras.preprocessing.image.load_img(image_path, target_size=(224, 224))\n",
    "    x = tf.keras.preprocessing.image.img_to_array(img)\n",
    "    x = np.expand_dims(x, axis=0)\n",
    "    x = tf.keras.applications.mobilenet_v2.preprocess_input(x)\n",
    "    preds = model.predict(x)\n",
    "    #print('Predictions', preds)\n",
    "    \n",
    "    for pred, value in label_map.items():    \n",
    "        if value == np.argmax(preds):\n",
    "            print('Predicted class is:', pred)\n",
    "            print('With a confidence score of: ', np.max(preds))\n",
    "    \n",
    "    return np.argmax(preds)"
   ]
  },
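  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In addition to the per-image helper above, we can also compute an aggregate score. The sketch below evaluates loss and accuracy on the same training generator (this demo has no separate validation split), so treat the result as optimistic."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate loss and accuracy over the training generator (no held-out set in this demo)\n",
    "loss, acc = model.evaluate_generator(train_generator, steps=len(train_generator))\n",
    "print('Loss: {:.4f}  Accuracy: {:.4f}'.format(loss, acc))"
   ]
  },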
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: We should host the test images ourselves.\n",
    "coffee_url = 'https://i5.walmartimages.com/asr/559143bb-3fd0-41b0-8358-fb0f74c41f8d_1.61dbeaff765619bbba67d4e519174932.jpeg'\n",
    "mug_url = 'https://cdn.cnn.com/cnnnext/dam/assets/150929101049-black-coffee-stock-super-tease.jpg'\n",
    "donut_url = 'https://i.ytimg.com/vi/gevpzxRxec4/maxresdefault.jpg'\n",
    "!wget {coffee_url} -O Coffee.jpg \n",
    "!wget {mug_url} -O Mug.jpg\n",
    "!wget {donut_url} -O Donut.jpg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Opening first image\n",
    "image = Image.open(\"Coffee.jpg\")\n",
    "image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#performing inference on above image\n",
    "prediction('Coffee.jpg', model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Opening second image\n",
    "image = Image.open(\"Mug.jpg\")\n",
    "image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prediction('Mug.jpg', model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Opening third image\n",
    "image = Image.open(\"Donut.jpg\")\n",
    "image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prediction('Donut.jpg', model)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.6",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}